{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import gc\n",
    "import codecs\n",
    "import warnings\n",
    "warnings.simplefilter('ignore')\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "from random import choice\n",
    "from sklearn.preprocessing import LabelEncoder\n",
    "from sklearn.model_selection import KFold\n",
    "\n",
    "from keras_bert import load_trained_model_from_checkpoint, Tokenizer\n",
    "\n",
    "from keras.utils import to_categorical\n",
    "from keras.layers import *\n",
    "from keras.callbacks import *\n",
    "from keras.models import Model\n",
    "import keras.backend as K\n",
    "from keras.optimizers import Adam"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "train = pd.read_csv('raw_data/train.csv', sep='\\t')\n",
    "test = pd.read_csv('raw_data/test.csv', sep='\\t')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "train.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "test.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "train.target.value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "train.stance.value_counts(dropna=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "train['text'] = train.text.fillna('')\n",
    "test['text'] = test.text.fillna('')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "train['stance'] = train['stance'].map({'AGAINST':0, 'FAVOR':1, 'NONE':2})\n",
    "train.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "maxlen = 256\n",
    "\n",
    "config_path = '/data/zhengheng/nlp_pretrain_models/chinese-roberta-wwm-ext-l12-h768-a12/bert_config.json'\n",
    "checkpoint_path = '/data/zhengheng/nlp_pretrain_models/chinese-roberta-wwm-ext-l12-h768-a12/bert_model.ckpt'\n",
    "dict_path = '/data/zhengheng/nlp_pretrain_models/chinese-roberta-wwm-ext-l12-h768-a12/vocab.txt'\n",
    "\n",
    "token_dict = {}\n",
    "with codecs.open(dict_path, 'r', 'utf8') as reader:\n",
    "    for line in reader:\n",
    "        token = line.strip()\n",
    "        token_dict[token] = len(token_dict)\n",
    "\n",
    "class OurTokenizer(Tokenizer):\n",
    "    def _tokenize(self, text):\n",
    "        R = []\n",
    "        for c in text:\n",
    "            if c in self._token_dict:\n",
    "                R.append(c)\n",
    "            elif self._is_space(c):\n",
    "                R.append('[unused1]') # space类用未经训练的[unused1]表示\n",
    "            else:\n",
    "                R.append('[UNK]') # 剩余的字符是[UNK]\n",
    "        return R\n",
    "\n",
    "tokenizer = OurTokenizer(token_dict)\n",
    "\n",
    "def seq_padding(X, padding=0):\n",
    "    L = [len(x) for x in X]\n",
    "    ML = max(L)\n",
    "    return np.array([\n",
    "        np.concatenate([x, [padding] * (ML - len(x))]) if len(x) < ML else x for x in X\n",
    "    ])\n",
    "\n",
    "class data_generator:\n",
    "    def __init__(self, data, batch_size=8, shuffle=True):\n",
    "        self.data = data\n",
    "        self.batch_size = batch_size\n",
    "        self.shuffle = shuffle\n",
    "        self.steps = len(self.data) // self.batch_size\n",
    "        if len(self.data) % self.batch_size != 0:\n",
    "            self.steps += 1\n",
    "    def __len__(self):\n",
    "        return self.steps\n",
    "    def __iter__(self):\n",
    "        while True:\n",
    "            idxs = list(range(len(self.data)))\n",
    "            \n",
    "            if self.shuffle:\n",
    "                np.random.shuffle(idxs)\n",
    "            \n",
    "            X1, X2, Y = [], [], []\n",
    "            for i in idxs:\n",
    "                d = self.data[i]\n",
    "                text = d[0][:maxlen]\n",
    "                x1, x2 = tokenizer.encode(first=text)\n",
    "                y = d[1]\n",
    "                X1.append(x1)\n",
    "                X2.append(x2)\n",
    "                Y.append([y])\n",
    "                if len(X1) == self.batch_size or i == idxs[-1]:\n",
    "                    X1 = seq_padding(X1)\n",
    "                    X2 = seq_padding(X2)\n",
    "                    Y = seq_padding(Y)\n",
    "                    yield [X1, X2], Y[:, 0, :]\n",
    "                    [X1, X2, Y] = [], [], []\n",
    "                    \n",
    "def build_bert():\n",
    "    bert_model = load_trained_model_from_checkpoint(config_path, checkpoint_path, seq_len=None)\n",
    "\n",
    "    for l in bert_model.layers:\n",
    "        l.trainable = True\n",
    "\n",
    "    x1_in = Input(shape=(None,))\n",
    "    x2_in = Input(shape=(None,))\n",
    "\n",
    "    x = bert_model([x1_in, x2_in])\n",
    "    x = Lambda(lambda x: x[:, 0])(x)\n",
    "    p = Dense(3, activation='softmax')(x)\n",
    "\n",
    "    model = Model([x1_in, x2_in], p)\n",
    "    model.compile(loss='categorical_crossentropy', \n",
    "                  optimizer=Adam(1e-5),\n",
    "                  metrics=['accuracy'])\n",
    "    print(model.summary())\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "DATA_LIST = []\n",
    "for data_row in train.iloc[:].itertuples():\n",
    "    DATA_LIST.append((data_row.text, to_categorical(data_row.stance, 3)))\n",
    "DATA_LIST = np.array(DATA_LIST)\n",
    "\n",
    "DATA_LIST_TEST = []\n",
    "for data_row in test.iloc[:].itertuples():\n",
    "    DATA_LIST_TEST.append((data_row.text, to_categorical(0, 3)))\n",
    "DATA_LIST_TEST = np.array(DATA_LIST_TEST)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "def run_cv(nfold, data, data_label, data_test):\n",
    "    \n",
    "    kf = KFold(n_splits=nfold, shuffle=True, random_state=1029).split(data)\n",
    "    train_model_pred = np.zeros((len(data), 3))\n",
    "    test_model_pred = np.zeros((len(data_test), 3))\n",
    "\n",
    "    for i, (train_fold, test_fold) in enumerate(kf):\n",
    "        X_train, X_valid, = data[train_fold, :], data[test_fold, :]\n",
    "        \n",
    "        model = build_bert()\n",
    "        early_stopping = EarlyStopping(monitor='val_acc', patience=3)\n",
    "        plateau = ReduceLROnPlateau(monitor=\"val_acc\", verbose=1, mode='max', factor=0.5, patience=2)\n",
    "        checkpoint = ModelCheckpoint('./' + str(i) + '.hdf5', monitor='val_acc', \n",
    "                                         verbose=2, save_best_only=True, mode='max', save_weights_only=True)\n",
    "        \n",
    "        train_D = data_generator(X_train, shuffle=True)\n",
    "        valid_D = data_generator(X_valid, shuffle=True)\n",
    "        test_D = data_generator(data_test, shuffle=False)\n",
    "        \n",
    "        model.fit_generator(\n",
    "            train_D.__iter__(),\n",
    "            steps_per_epoch=len(train_D),\n",
    "            epochs=5,\n",
    "            validation_data=valid_D.__iter__(),\n",
    "            validation_steps=len(valid_D),\n",
    "            callbacks=[early_stopping, plateau, checkpoint],\n",
    "        )\n",
    "        \n",
    "        train_model_pred[test_fold, :] =  model.predict_generator(valid_D.__iter__(), steps=len(valid_D), verbose=1)\n",
    "        test_model_pred += model.predict_generator(test_D.__iter__(), steps=len(test_D), verbose=1)\n",
    "        \n",
    "        del model; gc.collect()\n",
    "        K.clear_session()\n",
    "        \n",
    "    return train_model_pred, test_model_pred"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "train_model_pred, test_model_pred = run_cv(5, DATA_LIST, None, DATA_LIST_TEST)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "test_pred = [np.argmax(x) for x in test_model_pred]\n",
    "test['y'] = test_pred"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "submission = test[['y']]\n",
    "submission['id'] = submission.index\n",
    "submission['y'] = submission['y'].map({0:'AGAINST', 1:'FAVOR', 2:'NONE'})\n",
    "submission = submission[['id', 'y']]\n",
    "submission.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "submission.to_csv('submissions/submission_bert_ext_wwm_baseline.csv', index=False, header=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "!head 'submissions/submission_bert_ext_wwm_baseline.csv'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
